Kerry Back
The prediction in each cell (leaf) of the tree is the plurality class of the observations in that cell (for classification) or the mean of their target values (for regression).
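Get the data from the SQL database as before. The code below assumes a pandas DataFrame `data` with columns `ret`, `roeq`, and `mom12m`, and pandas imported as `pd`.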
# Fit a depth-2 classification tree that predicts a return-tercile class
# from two characteristics. Assumes `data` (a DataFrame with columns
# `ret`, `roeq`, `mom12m`) and `pd` (pandas) were set up in an earlier step.
from sklearn.tree import DecisionTreeClassifier

# Label each row by the tercile of its return: 0 = lowest third,
# 1 = middle third, 2 = highest third. qcut needs the whole column to
# compute the quantile edges, so call it on the Series directly rather
# than through Series.transform (which first tries the callable
# element-wise and only falls back to the whole Series after that fails).
# NOTE(review): the terciles are computed over the entire column, pooling
# all rows; if `data` spans several dates and per-date terciles are
# intended, group by date first — confirm.
data["class"] = pd.qcut(data["ret"], 3, labels=(0, 1, 2))

X = data[["roeq", "mom12m"]]
y = data["class"]

model = DecisionTreeClassifier(max_depth=2, random_state=0)
model.fit(X, y)
# Draw the fitted classification tree. Pass the column names so that split
# nodes read "roeq <= ..." / "mom12m <= ..." instead of the default
# positional labels "x[0] <= ..." / "x[1] <= ...".
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

plot_tree(model, feature_names=list(X.columns))
plt.show()
# Tabulate predicted versus actual classes for the classification tree.
# NOTE(review): X and y here are the same rows the tree was fit on, so this
# is an in-sample confusion matrix and will overstate out-of-sample accuracy.
from sklearn.metrics import ConfusionMatrixDisplay

ConfusionMatrixDisplay.from_estimator(
    model,
    X=X,
    y=y,
)
plt.show()
# Fit a depth-2 regression tree that predicts the raw return `ret` from the
# same two characteristics; each leaf predicts the mean return of its rows.
from sklearn.tree import DecisionTreeRegressor

X = data[["roeq", "mom12m"]]
y = data["ret"]

model = DecisionTreeRegressor(
    random_state=0,
    max_depth=2,
)
model.fit(X, y)
# Draw the fitted regression tree, labelling split nodes with the column
# names ("roeq", "mom12m") rather than the default "x[0]" / "x[1]".
plot_tree(model, feature_names=list(X.columns))
plt.show()
# Fit a depth-2 regression tree on the percentile rank of `ret` (values in
# (0, 1]) instead of the raw return — presumably to limit the influence of
# extreme returns on the squared-error splits.
# NOTE(review): the rank is taken over the whole column, pooling all rows;
# if per-date ranks are intended, group by date first — confirm.
data["rnk"] = data["ret"].rank(pct=True)

X = data[["roeq", "mom12m"]]
y = data["rnk"]

model = DecisionTreeRegressor(
    random_state=0,
    max_depth=2,
)
model.fit(X, y)
# Fit a depth-2 regression tree on the tercile class labels (0, 1, 2) built
# for the classifier. Because this is a regressor, each leaf predicts the
# mean label of its rows rather than the plurality class.
X = data[["roeq", "mom12m"]]
y = data["class"]

model = DecisionTreeRegressor(
    random_state=0,
    max_depth=2,
)
model.fit(X, y)